package com.abcode.st.websocket;
import com.abcode.st.websocket.common.ResponseMsg;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import javax.websocket.Session;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Session management class: keeps track of the currently open WebSocket sessions
 * and allows broadcasting a message to all of them.
 *
 * <p>Registration and un-registration are serialized through a single
 * {@link ReentrantLock} so that the session map and the online counter are
 * updated together. Broadcasting iterates the {@link ConcurrentHashMap} without
 * taking that lock, so it sees a weakly consistent view of the sessions.
 *
 * @author qinzhitao
 * @date 2023/3/21
 */
@Component
public class CurrencySocketService {
    private static final Logger log = LoggerFactory.getLogger(CurrencySocketService.class);

    /**
     * Number of currently registered connections.
     *
     * <p>NOTE(review): this counter is static while the session map is per-instance;
     * the two only agree as long as a single instance of this class exists — confirm
     * that it is only used as a singleton bean.
     */
    private static final AtomicInteger onlineCount = new AtomicInteger(0);

    /** Registered sessions, keyed by {@link Session#getId()}. */
    private final ConcurrentHashMap<String, Session> sessionUserMap = new ConcurrentHashMap<>();

    /**
     * Both registration and un-registration must hold this lock.
     */
    private final ReentrantLock lock = new ReentrantLock();

    /**
     * Returns the live (not copied) map of registered sessions.
     *
     * <p>Mutating the returned map directly bypasses the online counter maintained by
     * {@link #reg(Session)} and {@link #unReg(Session)}.
     *
     * @return the internal session map, keyed by session id
     */
    public ConcurrentHashMap<String, Session> getSessionUserMap() {
        return sessionUserMap;
    }

    public CurrencySocketService( ) {

    }

    /**
     * Registers a session and, if it was not registered before, increments the
     * online counter.
     *
     * <p>The checked exceptions in the signature are not thrown by this implementation;
     * they are kept so that existing callers which catch them keep compiling.
     *
     * @param session the newly opened session; must not be {@code null}
     * @throws NullPointerException if {@code session} is {@code null}
     * @throws FileNotFoundException never thrown by this implementation
     * @throws IOException never thrown by this implementation
     */
    public void reg(Session session) throws FileNotFoundException, IOException {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            // Resolve the id first: if the session is null this fails before any
            // state has been modified, so the counter cannot drift.
            String sessionId = session.getId();
            // Only count the connection when the id was not already registered;
            // a repeated registration just replaces the stored session.
            if (sessionUserMap.put(sessionId, session) == null) {
                onlineCount.incrementAndGet();
            }
            log.info("|reg|创建连接|sessionId:{}|在线人数：{}", sessionId, onlineCount.get());
        } finally {
            lock.unlock();
        }
    }

    /**
     * Removes a session and, if it was registered, decrements the online counter.
     * A {@code null} session is ignored.
     *
     * @param session the session that is being closed; may be {@code null}
     */
    public void unReg(Session session) {
        if (session == null) {
            return;
        }
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            String sessionId = session.getId();
            // remove() reports whether the id was present, so no separate
            // containsKey() lookup is needed.
            if (sessionUserMap.remove(sessionId) != null) {
                onlineCount.decrementAndGet();
            }
            log.info("|unReg|注销连接|sessionId:{}|在线人数：{}", sessionId, onlineCount.get());
        } finally {
            lock.unlock();
        }
    }

    /**
     * Sends a message to every registered session that is still open.
     *
     * <p>A failure while sending to one session is logged and does not prevent
     * delivery to the remaining sessions.
     *
     * @param msg the message to broadcast; its {@code toString()} form is sent as text.
     *            A {@code null} message is logged and ignored.
     * @author qinzhitao
     */
    public void sendMsgToAll(ResponseMsg msg) {
        if (msg == null) {
            log.warn("|sendMsgToAll|msg is null, nothing sent");
            return;
        }
        // The payload is identical for every session, so build it once.
        final String text = msg.toString();
        for (Session session : sessionUserMap.values()) {
            // Serialize sends per session: the lock object is the session itself,
            // so two broadcasts never write to the same session concurrently.
            synchronized (session) {
                try {
                    if (session.isOpen()) {
                        session.getBasicRemote().sendText(text);
                    }
                } catch (Exception e) {
                    // Trailing Throwable is taken by SLF4J as the exception to print;
                    // the two placeholders receive the session id and the message.
                    log.error("|sendMsgToAll|发生异常|sessionId：{}|{}", session.getId(), msg, e);
                }
            }
        }
    }
}
